# Copyright 2023 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load(
    "//:build_defs.bzl",
    "xnnpack_binary",
    "xnnpack_cxx_library",
    "xnnpack_kleidiai_defines",
    "xnnpack_test_deps_for_library",
    "xnnpack_unit_test",
)
load(
    "//:build_params.bzl",
    "xnnpack_select_if",
    "xnnpack_simd_copts_for_arch",
    "xnnpack_simd_f16_archs",
    "xnnpack_simd_f32_archs",
    "xnnpack_simd_s16_archs",
    "xnnpack_simd_s32_archs",
    "xnnpack_simd_s8_archs",
    "xnnpack_simd_u32_archs",
)

MICROKERNEL_TEST_DEPS = [
    ":next_prime",
    ":replicable_random_device",
    "//:aligned_allocator",
    "//:all_microkernels",
    "//:allocator",
    "//:common",
    "//:fp16",
    "//:isa_checks",
    "//:math",
    "//:memory",
    "//:microkernels_h",
    "//:microparams_init",
    "//:microparams",
    "//:packing",
    "//:params",
    "//:quantization",
    "//:requantization",
    "//:xnnpack_h",
]

OPERATOR_TEST_DEPS = [
    ":replicable_random_device",
    "@pthreadpool",
    "//:aligned_allocator",
    "//:allocator",
    "//:cache",
    "//:common",
    "//:fp16",
    "//:internal",
    "//:math",
    "//:microkernel_configs",
    "//:microparams",
    "//:normalization",
    "//:params",
    "//:quantization",
    "//:XNNPACK",
]

############################## Testing utilities ###############################

xnnpack_cxx_library(
    name = "replicable_random_device",
    testonly = True,
    hdrs = ["replicable_random_device.h"],
    deps = xnnpack_test_deps_for_library() + [
        "//:xnnpack_h",
    ],
)

xnnpack_cxx_library(
    name = "next_prime",
    testonly = True,
    srcs = ["next_prime.cc"],
    hdrs = ["next_prime.h"],
    deps = xnnpack_test_deps_for_library(),
)

xnnpack_cxx_library(
    name = "gemm_microkernel_tester",
    testonly = True,
    srcs = ["gemm-microkernel-tester.cc"],
    hdrs = ["gemm-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library() + [
        "//:XNNPACK",
        "//:config_hdrs",
    ],
)

xnnpack_cxx_library(
    name = "unary_operator_tester",
    testonly = True,
    srcs = ["unary-operator-tester.cc"],
    hdrs = ["unary-operator-tester.h"],
    deps = OPERATOR_TEST_DEPS + xnnpack_test_deps_for_library(),
)

xnnpack_cxx_library(
    name = "vunary_microkernel_tester",
    testonly = True,
    srcs = ["vunary-microkernel-tester.cc"],
    hdrs = ["vunary-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
)

xnnpack_cxx_library(
    name = "vbinary_microkernel_tester",
    testonly = True,
    srcs = ["vbinary-microkernel-tester.cc"],
    hdrs = ["vbinary-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
)

xnnpack_cxx_library(
    name = "dwconv_microkernel_tester",
    testonly = True,
    srcs = ["dwconv-microkernel-tester.cc"],
    hdrs = ["dwconv-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library() + [
        "//:microkernel_utils",
    ],
)

xnnpack_cxx_library(
    name = "vcvt_microkernel_tester",
    testonly = True,
    srcs = ["vcvt-microkernel-tester.cc"],
    hdrs = ["vcvt-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
)

xnnpack_cxx_library(
    name = "rdsum_microkernel_tester",
    testonly = True,
    hdrs = ["rdsum-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
)

xnnpack_cxx_library(
    name = "packq_microkernel_tester",
    testonly = True,
    srcs = ["packq-microkernel-tester.cc"],
    hdrs = ["packq-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
)

####################### Unit tests for microkernel lists #######################
sh_test(
    name = "microkernel_lists_test",
    size = "small",
    srcs = ["microkernel_lists_test.sh"],
    data = [
        "//:cmake_microkernel_lists",
        "//:generated_microkernel_lists",
        "//gen:bzl_microkernel_lists",
    ],
    target_compatible_with = xnnpack_select_if(
        "//build_config:linux",
        [],
        ["@platforms//:incompatible"],
    ),
)

######################### Unit tests for simd wrappers #########################
[xnnpack_unit_test(
    name = "f32_simd_" + arch + "_test",
    srcs = [
        "f32-simd-" + arch + ".cc",
    ],
    copts = xnnpack_simd_copts_for_arch(arch),
    deps = [
        ":replicable_random_device",
        "//:common",
        "//:isa_checks",
        "//:microkernels_h",
    ],
) for arch in xnnpack_simd_f32_archs()]

[xnnpack_unit_test(
    name = "f16_simd_" + arch + "_test",
    srcs = [
        "f16-simd-" + arch + ".cc",
    ],
    copts = xnnpack_simd_copts_for_arch(arch),
    deps = [
        ":replicable_random_device",
        "//:common",
        "//:fp16",
        "//:isa_checks",
        "//:microkernels_h",
    ],
) for arch in xnnpack_simd_f16_archs()]

[xnnpack_unit_test(
    name = "s16_simd_" + arch + "_test",
    srcs = [
        "s16-simd-" + arch + ".cc",
    ],
    copts = xnnpack_simd_copts_for_arch(arch),
    deps = [
        ":replicable_random_device",
        "//:common",
        "//:isa_checks",
        "//:microkernels_h",
    ],
) for arch in xnnpack_simd_s16_archs()]

[xnnpack_unit_test(
    name = "s32_simd_" + arch + "_test",
    srcs = [
        "s32-simd-" + arch + ".cc",
    ],
    copts = xnnpack_simd_copts_for_arch(arch),
    deps = [
        ":replicable_random_device",
        "//:common",
        "//:isa_checks",
        "//:microkernels_h",
    ],
) for arch in xnnpack_simd_s32_archs()]

[xnnpack_unit_test(
    name = "s8_simd_" + arch + "_test",
    srcs = [
        "s8-simd-" + arch + ".cc",
    ],
    copts = xnnpack_simd_copts_for_arch(arch),
    deps = [
        ":replicable_random_device",
        "//:common",
        "//:isa_checks",
        "//:microkernels_h",
    ],
) for arch in xnnpack_simd_s8_archs()]

[xnnpack_unit_test(
    name = "u32_simd_" + arch + "_test",
    srcs = [
        "u32-simd-" + arch + ".cc",
    ],
    copts = xnnpack_simd_copts_for_arch(arch),
    deps = [
        ":replicable_random_device",
        "//:common",
        "//:isa_checks",
        "//:microkernels_h",
    ],
) for arch in xnnpack_simd_u32_archs()]

######################### Unit tests for micro-kernels #########################

[xnnpack_unit_test(
    name = "%s_test" % kernel,
    srcs = [
        "%s.cc" % kernel.replace("_", "-"),
    ],
    deps = MICROKERNEL_TEST_DEPS + [":vunary_microkernel_tester"],
) for kernel in [
    "bf16_vabs",
    "f16_vabs",
    "f16_vclamp",
    "f16_velu",
    "f16_vhswish",
    "f16_vlrelu",
    "f16_vneg",
    "f16_vrndd",
    "f16_vrndne",
    "f16_vrndu",
    "f16_vrndz",
    "f16_vrsqrt",
    "f16_vsigmoid",
    "f16_vsqr",
    "f16_vsqrt",
    "f16_vtanh",
    "f32_vabs",
    "f32_vclamp",
    "f32_velu",
    "f32_vexp",
    "f32_vgelu",
    "f32_vhswish",
    "f32_vlog",
    "f32_vlrelu",
    "f32_vneg",
    "f32_vrelu",
    "f32_vrndd",
    "f32_vrndne",
    "f32_vrndu",
    "f32_vrndz",
    "f32_vrsqrt",
    "f32_vsigmoid",
    "f32_vsqr",
    "f32_vsqrt",
    "f32_vtanh",
    "s8_vclamp",
    "u8_vclamp",
]]

[xnnpack_unit_test(
    name = "%s_test" % kernel,
    srcs = [
        "%s.cc" % kernel.replace("_", "-"),
    ],
    deps = MICROKERNEL_TEST_DEPS + [":vbinary_microkernel_tester"],
) for kernel in [
    "f16_vadd",
    "f16_vaddc",
    "f16_vdiv",
    "f16_vdivc",
    "f16_vmax",
    "f16_vmaxc",
    "f16_vmin",
    "f16_vminc",
    "f16_vmul",
    "f16_vmulc",
    "f16_vprelu",
    "f16_vpreluc",
    "f16_vrpreluc",
    "f16_vrdivc",
    "f16_vrsubc",
    "f16_vsqrdiff",
    "f16_vsqrdiffc",
    "f16_vsub",
    "f32_vadd",
    "f32_vaddc",
    "f32_vcopysign",
    "f32_vcopysignc",
    "f32_vdiv",
    "f32_vdivc",
    "f32_vmax",
    "f32_vmaxc",
    "f32_vmin",
    "f32_vminc",
    "f32_vmul",
    "f32_vmulc",
    "f32_vprelu",
    "f32_vpreluc",
    "f32_vrpreluc",
    "f32_vrcopysignc",
    "f32_vrdivc",
    "f32_vrsubc",
    "f32_vsqrdiff",
    "f32_vsqrdiffc",
    "f32_vsub",
    "f32_vsubc",
    "qs8_vadd_minmax",
    "qs8_vaddc_minmax",
    "qs8_vmul_minmax_fp32",
    "qs8_vmul_minmax_rndnu",
    "qs8_vmulc_minmax_fp32",
    "qs8_vmulc_minmax_rndnu",
    "qu8_vadd_minmax",
    "qu8_vaddc_minmax",
    "qu8_vmul_minmax_fp32",
    "qu8_vmul_minmax_rndnu",
    "qu8_vmulc_minmax_fp32",
    "qu8_vmulc_minmax_rndnu",
    "s32_vmul",
    "s32_vmulc",
]]

[xnnpack_unit_test(
    name = "%s_test" % kernel,
    srcs = [
        "%s.cc" % kernel.replace("_", "-"),
    ],
    deps = MICROKERNEL_TEST_DEPS + [":vcvt_microkernel_tester"],
) for kernel in [
    "f16_f32_vcvt",
    "f16_qs8_vcvt",
    "f32_f16_vcvt",
    "f32_qs8_vcvt",
    "f32_qu8_vcvt",
    "s32_f32_vcvt",
    "u32_f32_vcvt",
    "qs16_qs8_vcvt",
    "qs8_f16_vcvt",
    "qs8_f32_vcvt",
    "qs8_vcvt",
    "qu8_vcvt",
    "qu8_f32_vcvt",
]]

[xnnpack_unit_test(
    name = "%s_test" % kernel,
    srcs = [
        "%s.cc" % kernel.replace("_", "-"),
        "reduce-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for kernel in [
    "f16_rmax",
    "f16_rmin",
    "f16_rminmax",
    "f32_rmax",
    "f32_rmin",
    "f32_rminmax",
]]

[xnnpack_unit_test(
    name = "%s_test" % kernel,
    srcs = [
        "%s.cc" % kernel.replace("_", "-"),
        "rsum-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for kernel in [
    "f16_rsum",
    "f16_f32acc_rsum",
    "f32_rsum",
    "qs8_rsum",
    "qu8_rsum",
]]

[xnnpack_unit_test(
    name = "%s_test" % kernel,
    srcs = [
        "%s.cc" % kernel.replace("_", "-"),
        "ibilinear-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for kernel in [
    "f32_ibilinear",
    "f32_ibilinear_chw",
    "f16_ibilinear",
    "f16_ibilinear_chw",
    "s8_ibilinear",
    "u8_ibilinear",
]]

[xnnpack_unit_test(
    name = "%s_test" % kernel,
    timeout = "moderate",
    srcs = [
        "%s.cc" % kernel.replace("_", "-"),
    ],
    shard_count = shard_count,
    deps = MICROKERNEL_TEST_DEPS + [
        ":dwconv_microkernel_tester",
    ],
) for (kernel, shard_count) in [
    ("f16_dwconv_minmax_multipass", 5),
    ("f16_dwconv_minmax_unipass", 5),
    ("f32_dwconv_unipass", 1),
    ("f32_dwconv_multipass", 5),
    ("f32_dwconv_minmax_unipass", 5),
    ("f32_dwconv_minmax_multipass", 5),
    ("qs8_qc8w_dwconv_minmax_multipass_fp32", 10),
    ("qs8_qc8w_dwconv_minmax_unipass_fp32", 10),
    ("qs8_dwconv_minmax_multipass_fp32", 10),
    ("qs8_dwconv_minmax_multipass_rndnu", 10),
    ("qs8_dwconv_minmax_unipass_fp32", 10),
    ("qs8_dwconv_minmax_unipass_rndnu", 1),
    ("qu8_dwconv_minmax_multipass_fp32", 10),
    ("qu8_dwconv_minmax_multipass_rndnu", 10),
    ("qu8_dwconv_minmax_unipass_fp32", 5),
    ("qu8_dwconv_minmax_unipass_rndnu", 1),
]]

xnnpack_unit_test(
    name = "maxpool_minmax_test",
    srcs = [
        "maxpool-microkernel-tester.h",
        "maxpool-minmax.cc",
    ],
    shard_count = 5,
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "bf16_gemm_minmax_test",
    srcs = [
        "bf16-gemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "avgpool_minmax_test",
    srcs = [
        "avgpool-microkernel-tester.h",
        "avgpool-minmax.cc",
    ],
    shard_count = 5,
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_gavgpool_cw_test",
    srcs = [
        "f16-gavgpool-cw.cc",
        "gavgpool-cw-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_gavgpool_minmax_test",
    srcs = [
        "f16-gavgpool-minmax.cc",
        "gavgpool-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_f32acc_gemm_minmax_test",
    srcs = [
        "f16-f32acc-gemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f16_gemm_minmax_test",
    srcs = [
        "f16-gemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f16_f32acc_igemm_minmax_test",
    srcs = [
        "f16-f32acc-igemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f16_igemm_minmax_test",
    srcs = [
        "f16-igemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f16_prelu_test",
    srcs = [
        "f16-prelu.cc",
        "prelu-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_spmm_minmax_test",
    srcs = [
        "f16-spmm-minmax.cc",
        "spmm-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_vmulcaddc_minmax_test",
    srcs = [
        "f16-vmulcaddc-minmax.cc",
        "vmulcaddc-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_raddstoreexpminusmax_test",
    srcs = [
        "f16-raddstoreexpminusmax.cc",
        "raddstoreexpminusmax-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_argmaxpool_test",
    srcs = [
        "argmaxpool-microkernel-tester.h",
        "f32-argmaxpool.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_igemm_test",
    srcs = [
        "f32-igemm.cc",
        "f32-igemm-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_igemm_relu_test",
    srcs = [
        "f32-igemm-relu.cc",
        "f32-igemm-relu-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_igemm_minmax_test",
    srcs = [
        "f32-igemm-minmax.cc",
        "f32-igemm-minmax-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_conv_hwc_test",
    srcs = [
        "conv-hwc-microkernel-tester.h",
        "f32-conv-hwc.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_conv_hwc2chw_test",
    srcs = [
        "conv-hwc2chw-microkernel-tester.h",
        "f16-conv-hwc2chw.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_conv_hwc2chw_test",
    srcs = [
        "conv-hwc2chw-microkernel-tester.h",
        "f32-conv-hwc2chw.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_dwconv2d_chw_test",
    srcs = [
        "dwconv2d-microkernel-tester.h",
        "f16-dwconv2d-chw.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_dwconv2d_chw_test",
    srcs = [
        "dwconv2d-microkernel-tester.h",
        "f32-dwconv2d-chw.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_gavgpool_minmax_test",
    srcs = [
        "f32-gavgpool-minmax.cc",
        "gavgpool-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_gavgpool_cw_test",
    srcs = [
        "f32-gavgpool-cw.cc",
        "gavgpool-cw-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_gemm_test",
    srcs = [
        "f32-gemm.cc",
        "f32-gemm-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_gemm_relu_test",
    srcs = [
        "f32-gemm-relu.cc",
        "f32-gemm-relu-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_gemm_minmax_test",
    srcs = [
        "f32-gemm-minmax.cc",
        "f32-gemm-minmax-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_gemm_goi_minmax_test",
    srcs = [
        "f32-gemm-goi-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_qc8w_gemm_test",
    srcs = [
        "f32-qc8w-gemm.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_qc8w_gemm_relu_test",
    srcs = [
        "f32-qc8w-gemm-relu.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_qc4w_gemm_minmax_test",
    srcs = [
        "f32-qc4w-gemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_qc8w_gemm_minmax_test",
    srcs = [
        "f32-qc8w-gemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_gemminc_minmax_test",
    srcs = [
        "f32-gemminc-minmax.cc",
        "f32-gemminc-minmax-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_ppmm_minmax_test",
    srcs = [
        "f32-ppmm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "f32_prelu_test",
    srcs = [
        "f32-prelu.cc",
        "prelu-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_raddexpminusmax_test",
    srcs = [
        "f32-raddexpminusmax.cc",
        "raddexpminusmax-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_raddextexp_test",
    srcs = [
        "f32-raddextexp.cc",
        "raddextexp-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_raddstoreexpminusmax_test",
    srcs = [
        "f32-raddstoreexpminusmax.cc",
        "raddstoreexpminusmax-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_f32acc_rdsum_test",
    srcs = [
        "f16-f32acc-rdsum.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [":rdsum_microkernel_tester"],
)

xnnpack_unit_test(
    name = "f32_rdsum_test",
    srcs = [
        "f32-rdsum.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [":rdsum_microkernel_tester"],
)

xnnpack_unit_test(
    name = "f32_spmm_minmax_test",
    srcs = [
        "f32-spmm-minmax.cc",
        "f32-spmm-minmax-2.cc",
        "f32-spmm-minmax-3.cc",
        "f32-spmm-minmax-4.cc",
        "spmm-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f16_vcmul_test",
    srcs = [
        "f16-vcmul.cc",
        "vcmul-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_vcmul_test",
    srcs = [
        "f32-vcmul.cc",
        "vcmul-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_vmulcaddc_minmax_test",
    srcs = [
        "f32-vmulcaddc-minmax.cc",
        "vmulcaddc-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_vscaleexpminusmax_test",
    srcs = [
        "f32-vscaleexpminusmax.cc",
        "vscaleexpminusmax-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "f32_vscaleextexp_test",
    srcs = [
        "f32-vscaleextexp.cc",
        "vscaleextexp-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "qd8_f16_qc8w_gemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f16-qc8w-gemm-minmax.cc",
        "qd8-f16-qc8w-gemm-minmax-2.cc",
        "qd8-f16-qc8w-gemm-minmax-3.cc",
        "qd8-f16-qc8w-gemm-minmax-4.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qd8_f32_qc8w_gemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f32-qc8w-gemm-minmax.cc",
        "qd8-f32-qc8w-gemm-minmax-2.cc",
        "qd8-f32-qc8w-gemm-minmax-3.cc",
        "qd8-f32-qc8w-gemm-minmax-4.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qd8_f16_qc4w_gemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f16-qc4w-gemm-minmax.cc",
        "qd8-f16-qc4w-gemm-minmax-2.cc",
        "qd8-f16-qc4w-gemm-minmax-3.cc",
        "qd8-f16-qc4w-gemm-minmax-4.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qd8_f32_qc4w_gemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f32-qc4w-gemm-minmax.cc",
        "qd8-f32-qc4w-gemm-minmax-2.cc",
        "qd8-f32-qc4w-gemm-minmax-3.cc",
        "qd8-f32-qc4w-gemm-minmax-4.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qd8_f16_qb4w_gemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f16-qb4w-gemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qd8_f32_qb4w_gemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f32-qb4w-gemm-minmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qp8_f32_qc4w_gemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qp8-f32-qc4w-gemm-minmax.cc",
    ],
    defines = xnnpack_kleidiai_defines(),
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qs8_qc8w_gemm_minmax_fp32_test",
    timeout = "moderate",
    srcs = [
        "qs8-qc8w-gemm-minmax-fp32.cc",
        "qs8-qc8w-gemm-minmax-fp32-2.cc",
        "qs8-qc8w-gemm-minmax-fp32-3.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qd8_f16_qc8w_igemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f16-qc8w-igemm-minmax.cc",
        "qd8-f16-qc8w-igemm-minmax-2.cc",
        "qd8-f16-qc8w-igemm-minmax-3.cc",
        "qd8-f16-qc8w-igemm-minmax-4.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qd8_f32_qc8w_igemm_minmax_test",
    srcs = [
        "qd8-f32-qc8w-igemm-minmax.cc",
        "qd8-f32-qc8w-igemm-minmax-2.cc",
        "qd8-f32-qc8w-igemm-minmax-3.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qs8_qc8w_igemm_minmax_fp32_test",
    timeout = "moderate",
    srcs = [
        "qs8-qc8w-igemm-minmax-fp32.cc",
        "qs8-qc8w-igemm-minmax-fp32-2.cc",
        "qs8-qc8w-igemm-minmax-fp32-3.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qs8_gavgpool_minmax_fp32_test",
    srcs = [
        "gavgpool-microkernel-tester.h",
        "qs8-gavgpool-minmax-fp32.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "qs8_gavgpool_minmax_rndnu_test",
    srcs = [
        "gavgpool-microkernel-tester.h",
        "qs8-gavgpool-minmax-rndnu.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "qs8_requantization_test",
    srcs = [
        "qs8-requantization.cc",
        "requantization-tester.h",
    ],
    shard_count = 5,
    deps = MICROKERNEL_TEST_DEPS + ["//:requantization_stubs"],
)

xnnpack_unit_test(
    name = "qs8_rdsum_minmax_fp32_test",
    srcs = [
        "qs8-rdsum-minmax-fp32.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [":rdsum_microkernel_tester"],
)

xnnpack_unit_test(
    name = "qu8_rdsum_test",
    srcs = [
        "qu8-rdsum.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [":rdsum_microkernel_tester"],
)

xnnpack_unit_test(
    name = "qs8_vhswish_test",
    srcs = [
        "qs8-vhswish.cc",
        "vhswish-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "qs8_vlrelu_test",
    srcs = [
        "qs8-vlrelu.cc",
        "vlrelu-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "qu8_gavgpool_minmax_fp32_test",
    srcs = [
        "gavgpool-microkernel-tester.h",
        "qu8-gavgpool-minmax-fp32.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "qu8_gavgpool_minmax_rndnu_test",
    srcs = [
        "gavgpool-microkernel-tester.h",
        "qu8-gavgpool-minmax-rndnu.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "qu8_gemm_minmax_fp32_test",
    srcs = [
        "qu8-gemm-minmax-fp32.cc",
        "qu8-gemm-minmax-fp32-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qu8_gemm_minmax_rndnu_test",
    srcs = [
        "qu8-gemm-minmax-rndnu.cc",
        "qu8-gemm-minmax-rndnu-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qu8_igemm_minmax_fp32_test",
    srcs = [
        "qu8-igemm-minmax-fp32.cc",
        "qu8-igemm-minmax-fp32-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qu8_igemm_minmax_rndnu_test",
    srcs = [
        "qu8-igemm-minmax-rndnu.cc",
        "qu8-igemm-minmax-rndnu-2.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":gemm_microkernel_tester",
    ],
)

xnnpack_unit_test(
    name = "qu8_requantization_test",
    srcs = [
        "qu8-requantization.cc",
        "requantization-tester.h",
    ],
    shard_count = 5,
    deps = MICROKERNEL_TEST_DEPS + ["//:requantization_stubs"],
)

xnnpack_unit_test(
    name = "qu8_vhswish_test",
    srcs = [
        "qu8-vhswish.cc",
        "vhswish-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "qu8_vlrelu_test",
    srcs = [
        "qu8-vlrelu.cc",
        "vlrelu-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "u8_lut32norm_test",
    srcs = [
        "lut-norm-microkernel-tester.h",
        "u8-lut32norm.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "u8_rmax_test",
    srcs = [
        "reduce-microkernel-tester.h",
        "u8-rmax.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "x8_lut_test",
    srcs = [
        "lut-microkernel-tester.h",
        "x8-lut.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "x8_zip_test",
    srcs = [
        "x8-zip.cc",
        "zip-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "x32_packb_test",
    srcs = [
        "packb-microkernel-tester.h",
        "x32-packb.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "x8_packq_test",
    srcs = [
        "x8-packq.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        ":packq_microkernel_tester",
    ],
)

[xnnpack_unit_test(
    name = "%s_test" % kernel,
    srcs = [
        "packw-microkernel-tester.h",
        "%s.cc" % kernel.replace("_", "-"),
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for kernel in [
    "x8_packw",
    "qs8_packw",
    "x16_packw",
    "x32_packw",
]]

xnnpack_unit_test(
    name = "x32_packx_test",
    srcs = [
        "pack-microkernel-tester.h",
        "x32-packx.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "xN_transpose_test",
    srcs = ["xN-transpose.cc"],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "x32_unpool_test",
    srcs = [
        "unpool-microkernel-tester.h",
        "x32-unpool.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "x32_zip_test",
    srcs = [
        "x32-zip.cc",
        "zip-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "xx_fill_test",
    srcs = ["xx-fill.cc"],
    deps = MICROKERNEL_TEST_DEPS,
)

xnnpack_unit_test(
    name = "xx_pad_test",
    srcs = ["xx-pad.cc"],
    deps = MICROKERNEL_TEST_DEPS,
)

########################## Size tests for the library #########################

xnnpack_binary(
    name = "operator_size_test",
    srcs = ["operator-size.c"],
    deps = ["//:XNNPACK"],
)

xnnpack_binary(
    name = "subgraph_size_test",
    srcs = ["subgraph-size.c"],
    deps = ["//:XNNPACK"],
)

########################### Unit tests for operators ##########################

[xnnpack_unit_test(
    name = "%s_test" % operator,
    srcs = [
        "%s.cc" % operator.replace("_", "-"),
    ],
    deps = OPERATOR_TEST_DEPS + [
        ":unary_operator_tester",
    ],
) for operator in [
    "abs_nc",
    "bankers_rounding_nc",
    "ceiling_nc",
    "clamp_nc",
    "elu_nc",
    "exp_nc",
    "floor_nc",
    "gelu_nc",
    "hardswish_nc",
    "leaky_relu_nc",
    "log_nc",
    "negate_nc",
    "reciprocal_square_root_nc",
    "sigmoid_nc",
    "square_nc",
    "square_root_nc",
    "tanh_nc",
    "truncation_nc",
]]

xnnpack_unit_test(
    name = "binary_elementwise_nd_test",
    timeout = "long",
    srcs = ["binary-elementwise-nd.cc"],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "argmax_pooling_nhwc_test",
    srcs = [
        "argmax-pooling-nhwc.cc",
        "argmax-pooling-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "average_pooling_nhwc_test",
    srcs = [
        "average-pooling-nhwc.cc",
        "average-pooling-operator-tester.h",
    ],
    shard_count = 5,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "batch_matrix_multiply_nc_test",
    timeout = "long",
    srcs = [
        "batch-matrix-multiply-nc.cc",
        "batch-matrix-multiply-operator-tester.h",
    ],
    shard_count = 2,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "channel_shuffle_nc_test",
    srcs = [
        "channel-shuffle-nc.cc",
        "channel-shuffle-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "constant_pad_nd_test",
    srcs = [
        "constant-pad-nd.cc",
        "constant-pad-operator-tester.h",
    ],
    shard_count = 5,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "constant_pad_nd_eager_test",
    srcs = [
        "constant-pad-nd-eager.cc",
        "constant-pad-operator-tester.h",
    ],
    shard_count = 5,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "convert_nc_test",
    timeout = "moderate",
    srcs = [
        "convert-nc.cc",
        "convert-operator-tester.h",
    ],
    shard_count = 5,
    deps = OPERATOR_TEST_DEPS + ["//:microkernels_h"],
)

xnnpack_unit_test(
    name = "convert_nc_eager_test",
    srcs = [
        "convert-nc-eager.cc",
        "convert-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS + ["//:microkernels_h"],
)

xnnpack_unit_test(
    name = "convolution_nhwc_test",
    timeout = "moderate",
    srcs = [
        "convolution-nhwc.cc",
        "convolution-operator-tester.h",
    ],
    shard_count = 10,
    deps = OPERATOR_TEST_DEPS + [
        ":convolution_test_helpers",
    ],
)

xnnpack_unit_test(
    name = "convolution_nchw_test",
    timeout = "moderate",
    srcs = [
        "convolution-nchw.cc",
        "convolution-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS + [
        ":convolution_test_helpers",
    ],
)

xnnpack_unit_test(
    name = "copy_nc_test",
    srcs = [
        "copy-nc.cc",
        "copy-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "copy_nc_eager_test",
    srcs = [
        "copy-nc-eager.cc",
        "copy-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "deconvolution_nhwc_test",
    timeout = "moderate",
    srcs = [
        "deconvolution-nhwc.cc",
        "deconvolution-nhwc-qd8-f32-qc8w.cc",
        "deconvolution-operator-tester.h",
    ],
    shard_count = 10,
    deps = OPERATOR_TEST_DEPS + [
        ":convolution_test_helpers",
    ],
)

xnnpack_unit_test(
    name = "depth_to_space_nchw2nhwc_test",
    srcs = [
        "depth-to-space-nchw2nhwc.cc",
        "depth-to-space-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "depth_to_space_nhwc_test",
    srcs = [
        "depth-to-space-nhwc.cc",
        "depth-to-space-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "dynamic_fully_connected_nc_test",
    srcs = [
        "dynamic-fully-connected-nc.cc",
        "dynamic-fully-connected-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "fully_connected_nc_test",
    srcs = [
        "fully-connected-nc.cc",
        "fully-connected-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS + [
        "//:microkernels_h",
    ],
)

xnnpack_unit_test(
    name = "global_average_pooling_nwc_test",
    timeout = "moderate",
    srcs = [
        "global-average-pooling-nwc.cc",
        "global-average-pooling-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "global_average_pooling_ncw_test",
    srcs = [
        "global-average-pooling-ncw.cc",
        "global-average-pooling-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "global_sum_pooling_nwc_test",
    srcs = [
        "global-sum-pooling-nwc.cc",
        "global-sum-pooling-operator-tester.h",
    ],
    shard_count = 10,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "max_pooling_nhwc_test",
    timeout = "moderate",
    srcs = [
        "max-pooling-nhwc.cc",
        "max-pooling-operator-tester.h",
    ],
    shard_count = 10,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "mean_nd_test",
    timeout = "moderate",
    srcs = [
        "mean-nd.cc",
        "mean-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS + ["//:requantization"],
)

xnnpack_unit_test(
    name = "slice_normalization_test",
    srcs = [
        "slice-normalization.cc",
        "slice-normalization-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "reduce_normalization_test",
    srcs = [
        "reduce-normalization.cc",
        "reduce-normalization-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "transpose_normalization_test",
    srcs = [
        "transpose-normalization.cc",
        "transpose-normalization-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "prelu_nc_test",
    srcs = [
        "prelu-nc.cc",
        "prelu-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "resize_bilinear_nhwc_test",
    srcs = [
        "resize-bilinear-nhwc.cc",
        "resize-bilinear-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "resize_bilinear_nchw_test",
    srcs = [
        "resize-bilinear-nchw.cc",
        "resize-bilinear-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "rope_nthc_test",
    srcs = [
        "rope-nthc.cc",
        "rope-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "scaled_dot_product_attention_nhtc_test",
    timeout = "moderate",
    srcs = [
        "scaled-dot-product-attention-nhtc.cc",
        "scaled-dot-product-attention-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "slice_nd_test",
    timeout = "moderate",
    srcs = [
        "slice-nd.cc",
        "slice-operator-tester.h",
    ],
    shard_count = 5,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "slice_nd_eager_test",
    timeout = "moderate",
    srcs = [
        "slice-nd-eager.cc",
        "slice-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "softmax_nc_test",
    srcs = [
        "softmax-nc.cc",
        "softmax-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "space_to_depth_nhwc_test",
    srcs = [
        "space-to-depth-nhwc.cc",
        "space-to-depth-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "transpose_nd_test",
    srcs = [
        "transpose-nd.cc",
        "transpose-operator-tester.h",
    ],
    shard_count = 10,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "transpose_nd_eager_test",
    srcs = [
        "transpose-nd-eager.cc",
        "transpose-operator-tester.h",
    ],
    shard_count = 5,
    deps = OPERATOR_TEST_DEPS,
)

xnnpack_unit_test(
    name = "unpooling_nhwc_test",
    srcs = [
        "unpooling-nhwc.cc",
        "unpooling-operator-tester.h",
    ],
    deps = OPERATOR_TEST_DEPS,
)

########################### Unit tests for subgraph ###########################

xnnpack_cxx_library(
    name = "convolution_test_helpers",
    testonly = True,
    srcs = [
        "convolution-test-helpers.cc",
    ],
    hdrs = [
        "convolution-test-helpers.h",
    ],
    deps = [
        "//:microparams",
        "//:quantization",
    ],
)

xnnpack_cxx_library(
    name = "subgraph_unary_tester",
    testonly = True,
    hdrs = [
        "subgraph-unary-tester.h",
    ],
    deps = xnnpack_test_deps_for_library() + [
        ":replicable_random_device",
        "//:node_type",
        "//:operators",
        "//:requantization",
        "//:subgraph",
        "//:xnnpack_h",
    ],
)

[xnnpack_unit_test(
    name = "%s_test" % operator,
    srcs = [
        "%s.cc" % operator.replace("_", "-"),
    ],
    deps = [
        ":replicable_random_device",
        ":subgraph_unary_tester",
        "//:XNNPACK",
        "//:math",
        "//:node_type",
        "//:operators",
        "//:requantization",
        "//:subgraph",
    ],
) for operator in [
    "abs",
    "bankers_rounding",
    "ceiling",
    "clamp",
    "convert",
    "copy",
    "elu",
    "exp",
    "floor",
    "gelu",
    "hardswish",
    "leaky_relu",
    "log",
    "negate",
    "reciprocal_square_root",
    "sigmoid",
    "softmax",
    "space_to_depth_2d",
    "square",
    "square_root",
    "static_constant_pad",
    "static_expand_dims",
    "static_reshape",
    "static_slice",
    "static_transpose",
    "tanh",
]]

xnnpack_cxx_library(
    name = "subgraph_binary_tester",
    testonly = True,
    hdrs = [
        "subgraph-binary-tester.h",
    ],
    deps = xnnpack_test_deps_for_library() + [
        ":replicable_random_device",
        "//:node_type",
        "//:operators",
        "//:requantization",
        "//:subgraph",
        "//:xnnpack_h",
    ],
)

xnnpack_unit_test(
    name = "workspace_test",
    srcs = [
        "workspace.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:allocation_type",
        "//:math",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "abs_reshape_test",
    srcs = [
        "abs-reshape.cc",
    ],
    deps = [
        "//:XNNPACK",
        "//:node_type",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "binary_test",
    srcs = ["binary.cc"],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:math",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "argmax_pooling_2d_test",
    srcs = [
        "argmax-pooling-2d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "average_pooling_2d_test",
    srcs = [
        "average-pooling-2d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operator_utils",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "average_pooling_2d_reshape_test",
    srcs = [
        "average-pooling-2d-reshape.cc",
    ],
    deps = [
        "//:XNNPACK",
        "//:node_type",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "batch_matrix_multiply_test",
    srcs = [
        "batch-matrix-multiply.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

[xnnpack_unit_test(
    name = "concatenate%d_test" % n,
    srcs = [
        "concatenate%d.cc" % n,
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
) for n in [
    2,
    3,
    4,
    5,
]]

xnnpack_unit_test(
    name = "convolution_2d_test",
    srcs = [
        "convolution-2d.cc",
    ],
    deps = [
        ":convolution_test_helpers",
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operator_utils",
        "//:operators",
        "//:requantization",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "deconvolution_2d_test",
    timeout = "moderate",
    srcs = [
        "deconvolution-2d.cc",
    ],
    shard_count = 5,
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:operator_utils",
        "//:operators",
        "//:requantization",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "depth_to_space_2d_test",
    srcs = [
        "depth-to-space-2d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "depthwise_convolution_2d_test",
    srcs = [
        "depthwise-convolution-2d.cc",
    ],
    deps = [
        ":convolution_test_helpers",
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operator_utils",
        "//:operators",
        "//:requantization",
        "//:subgraph",
    ],
)

[xnnpack_unit_test(
    name = "even_split%d_test" % n,
    srcs = [
        "even-split%s.cc" % n,
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
) for n in [
    2,
    3,
    4,
]]

xnnpack_unit_test(
    name = "fully_connected_test",
    timeout = "moderate",
    srcs = [
        "fully-connected.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:config_hdrs",
        "//:internal",
        "//:math",
        "//:microkernels_h",
        "//:node_type",
        "//:operators",
        "//:requantization",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "global_average_pooling_1d_test",
    srcs = [
        "global-average-pooling-1d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operators",
        "//:requantization",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "global_average_pooling_2d_test",
    srcs = [
        "global-average-pooling-2d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operators",
        "//:requantization",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "global_sum_pooling_1d_test",
    srcs = [
        "global-sum-pooling-1d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "global_sum_pooling_2d_test",
    srcs = [
        "global-sum-pooling-2d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "max_pooling_2d_test",
    srcs = [
        "max-pooling-2d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:operator_utils",
        "//:operators",
        "//:requantization",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "prelu_test",
    srcs = [
        "prelu.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "rope_test",
    srcs = [
        "rope.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:math",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "scaled_dot_product_attention_test",
    srcs = [
        "scaled-dot-product-attention.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "static_mean_test",
    srcs = [
        "static-mean.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:aligned_allocator",
        "//:common",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "reshape_helpers_test",
    srcs = [
        "reshape-helpers.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "static_resize_bilinear_2d_test",
    srcs = [
        "static-resize-bilinear-2d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "transpose_reshape_test",
    srcs = [
        "transpose-reshape.cc",
    ],
    deps = [
        "//:XNNPACK",
        "//:node_type",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "unpooling_2d_test",
    srcs = [
        "unpooling-2d.cc",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:operator_utils",
        "//:operators",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "fusion_test",
    srcs = [
        "fusion.cc",
        "runtime-tester.h",
        "subgraph-tester.h",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:subgraph",
    ],
)

############################### Misc unit tests ###############################

xnnpack_unit_test(
    name = "runtime_test",
    srcs = [
        "runtime.cc",
        "runtime-tester.h",
        "subgraph-tester.h",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "subgraph_test",
    srcs = [
        "runtime-tester.h",
        "subgraph.cc",
        "subgraph-tester.h",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "memory_planner_test",
    srcs = [
        "memory-planner.cc",
        "runtime-tester.h",
        "subgraph-tester.h",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:node_type",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "subgraph_nchw_test",
    srcs = [
        "subgraph-nchw.cc",
        "subgraph-tester.h",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "subgraph_fp16_test",
    srcs = [
        "mock-allocator.h",
        "runtime-tester.h",
        "subgraph-fp16.cc",
        "subgraph-tester.h",
    ],
    deps = [
        ":replicable_random_device",
        "//:XNNPACK",
        "//:allocation_type",
        "//:allocator",
        "//:node_type",
        "//:params",
        "//:subgraph",
    ],
)

xnnpack_unit_test(
    name = "weights_cache_test",
    srcs = ["weights-cache.cc"],
    deps = [
        "//:XNNPACK",
        "//:cache",
        "//:common",
        "//:memory",
    ],
)

xnnpack_unit_test(
    name = "mutex_test",
    srcs = ["mutex.cc"],
    deps = [
        ":replicable_random_device",
        "//:common",
        "//:mutex",
        "//:xnnpack_h",
    ],
)

xnnpack_unit_test(
    name = "microkernel_utils_test",
    srcs = ["microkernel-utils.cc"],
    deps = ["//:microkernel_utils"],
)

xnnpack_unit_test(
    name = "operator_utils_test",
    srcs = ["operator-utils.cc"],
    deps = [
        "//:microkernel_configs",
        "//:operator_utils",
    ],
)

xnnpack_unit_test(
    name = "packing_test",
    srcs = [
        "packing.cc",
    ],
    deps = MICROKERNEL_TEST_DEPS + [
        "//:microkernel_utils",
        "//:operator_utils",
    ],
)

xnnpack_unit_test(
    name = "indirection_test",
    srcs = [
        "indirection.cc",
    ],
    deps = [
        "//:indirection",
        "//:math",
        "//:operator_utils",
        "//:operators",
        "//:xnnpack_h",
    ],
)

xnnpack_unit_test(
    name = "build_identifier_test",
    srcs = [
        "build-identifier.cc",
    ],
    deps = [
        "//:XNNPACK",
    ],
)
